// STL
#include <cmath>     // std::log(), std::exp(), std::sqrt()
#include <stdexcept> // std::runtime_error()
#include <string>    // std::to_string()
#include <vector>    // std::vector<>

// jScience
#include "jScience/linalg.hpp"         // Matrix<>, Vector<>, dot_product(), transpose()
#include "jScience/physics/consts.hpp" // PI


// NOTE: for energy function and free energy formulas, see: http://wiki.statsxx.com/view/Restricted_Boltzmann_machine

// NOTE: in each of the following, constant terms of the free energy (terms that do not depend on the visible vector v) are not calculated, as they are not informative

// NOTE: the softplus free energy is approximated by the ReLU free energy (the exact softplus form requires complicated series expansions for the antiderivative)


// File-local helpers. Each one evaluates one part of the RBM free energy and
// throws std::runtime_error when given an unrecognized unit type.

// visible-unit term: depends on the visible biases a and on the visible vector v
static double free_energy_visible(const int vtype, const Vector<double> &a, const Vector<double> &v);

// hidden-unit term: depends on the weights W, the hidden biases b and on the visible vector v
static double free_energy_hidden(const int htype, const Matrix<double> &W, const Vector<double> &b, const Vector<double> &v);

// v-independent constant term (not included by RBM_free_energy)
static double free_energy_const(const int htype, const int nh);


//
// DESC: Free energy F(v) of a restricted Boltzmann machine for a single visible
//       vector v, returned as (visible-unit term) + (hidden-unit term).
//       Terms that do not depend on v are deliberately left out.
//
//       vtype, htype : unit type of the visible / hidden layer
//                      (0 == binary; 1 == continuous; 2 == ReLU; 3 == softplus)
//       nv, nh       : not used here; presumably the numbers of visible / hidden
//                      units (nh would only matter for the omitted constant term)
//       W            : weight matrix (the hidden term uses transpose(W)*v)
//       a, b         : visible / hidden bias vectors
//       v            : visible vector
//
//       Throws std::runtime_error (from the helpers) if vtype or htype is not recognized.
//
double RBM_free_energy(
                       const int             vtype, // 0 == binary; 1 == continuous; 2 == ReLU; 3 == softplus
                       const int             htype, // ... same
                       // -----
                       const int             nv,
                       const int             nh,
                       // -----
                       const Matrix<double> &W,
                       const Vector<double> &a,
                       const Vector<double> &b,
                       // -----
                       const Vector<double> &v
                       )
{
    // visible contribution is evaluated first, so an invalid vtype is reported before an invalid htype
    const double F_visible = free_energy_visible(vtype, a, v);
    const double F_hidden  = free_energy_hidden(htype, W, b, v);
    
    // The v-independent constant, free_energy_const(htype, nh), is intentionally not added:
    // it shifts every free energy by the same amount and carries no information about v.
    return F_visible + F_hidden;
}


//
// DESC: Free energies of a batch of visible vectors, one per row of V.
//       Element i of the result equals
//           RBM_free_energy(vtype, htype, nv, nh, W, a, b, V.row(i)),
//       so all arguments have the same meaning as in the single-vector overload.
//
//       Throws std::runtime_error (from the single-vector overload) if vtype or
//       htype is not recognized.
//
std::vector<double> RBM_free_energy(
                            const int             vtype,
                            const int             htype,
                            // -----
                            const int             nv,
                            const int             nh,
                            // -----
                            const Matrix<double> &W,
                            const Vector<double> &a,
                            const Vector<double> &b,
                            // -----
                            const Matrix<double> &V
                            )
{
    const auto n_rows = V.size(0);
    
    std::vector<double> free_energies;
    // the output size is known up front: allocate once instead of growing in push_back()
    free_energies.reserve(n_rows);
    
    for(auto i = 0; i < n_rows; ++i)
    {
        free_energies.push_back(RBM_free_energy(
                                                vtype,
                                                htype,
                                                // -----
                                                nv,
                                                nh,
                                                // -----
                                                W,
                                                a,
                                                b,
                                                // -----
                                                V.row(i)
                                                ));
    }
    
    return free_energies;
}


//
// DESC: Visible-unit contribution to the RBM free energy for the visible
//       vector v and visible biases a.
//
//       The linear bias term -a.v is always present; vtype selects an extra
//       self-energy term:
//         0    (binary)       : none
//         1    (continuous)   : + (1/2) v.v
//         2, 3 (ReLU/softplus): + sum of v_i*v_i over the components with v_i > 0
//                               (softplus is treated exactly like ReLU here)
//
//       Throws std::runtime_error for any other vtype.
//
static double free_energy_visible(
                                  const int            vtype,
                                  // -----
                                  const Vector<double> &a,
                                  // -----
                                  const Vector<double> &v
                                  )
{
    // bias term shared by every unit type (evaluated before vtype is validated)
    double F = -dot_product(a, v);
    
    if(vtype == 0)
    {
        return F;
    }
    
    if(vtype == 1)
    {
        F += 0.5*dot_product(v, v);
        return F;
    }
    
    if((vtype == 2) || (vtype == 3))
    {
        for(int k = 0; k < v.size(); ++k)
        {
            const double vk = v(k);
            
            // written as !(vk > 0.) so that NaN components are skipped as well
            if(!(vk > 0.))
            {
                continue;
            }
            
            F += vk*vk;
        }
        return F;
    }
    
    throw std::runtime_error("error: RBM type vtype=" + std::to_string(vtype) + " not recognized (or implemented) in free_energy_visible()");
}

//
// DESC:
//
static double free_energy_hidden(
                                 const int            htype,
                                 // -----
                                 const Matrix<double> &W,
                                 const Vector<double> &b,
                                 // -----
                                 const Vector<double> &v
                                 )
{
    double free_energy;
    
    Vector<double> x = transpose(W)*v + b;
    
    switch(htype)
    {
        case 0:
            // NOTE: repeated multiplication followed by one call to std::log() should be faster than repeated addition of calls to std::log()
            free_energy = 1.;
            for(auto i = 0; i < x.size(); ++i)
            {
                free_energy *= (1. + std::exp(x(i)));
            }
            free_energy = -std::log(free_energy);
            break;
        case 1:
            free_energy = -0.5*dot_product(
                                           x,
                                           x
                                           );
            break;
        case 2:
        case 3:
            free_energy = -0.25*dot_product(
                                            x,
                                            x
                                            );
            break;
        default:
            throw std::runtime_error("error: RBM type htype=" + std::to_string(htype) + " not recognized (or implemented) in free_energy_hidden()");
            break;
    }
    
    return free_energy;
}

//
// DESC: Constant (v-independent) part of the hidden-unit free energy for nh
//       hidden units of type htype. RBM_free_energy does not include it.
//         0    (binary)       : 0
//         1    (continuous)   : -nh*log(2*sqrt(PI/2))   (note 2*sqrt(PI/2) == sqrt(2*PI))
//         2, 3 (ReLU/softplus): -nh*log(PI/2)
//
//       Throws std::runtime_error for any other htype.
//
static double free_energy_const(
                                const int            htype,
                                // -----
                                const int            nh
                                )
{
    switch(htype)
    {
        case 0:
            return 0.;
        case 1:
            return -nh*std::log(2*std::sqrt(PI/2));
        case 2:
        case 3:
            return -nh*std::log(PI/2);
        default:
            break;
    }
    
    throw std::runtime_error("error: RBM type htype=" + std::to_string(htype) + " not recognized (or implemented) in free_energy_const()");
}